{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = torch.Tensor([0.2])\n",
    "\n",
    "def RBF_kernel(x,x_,dim):\n",
    "    x = x.view(-1,1,dim)\n",
    "    x_ = x_.view(1,-1,dim)\n",
    "    dist = (x-x_).pow(2).sum(dim=-1)\n",
    "    return torch.exp(-dist/2/sigma.pow(2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GaussianProcess:\n",
    "    \n",
    "    def __init__(self,dim_x,dim_y,var,kernel=None):\n",
    "        self.dim_x = dim_x\n",
    "        self.dim_y = dim_y\n",
    "        self.var = var\n",
    "        if kernel:\n",
    "            self.kernel = kernel\n",
    "        else:\n",
    "            self.kernel = RBF_kernel\n",
    "            \n",
    "    def train(self,x,y):\n",
    "            self.train_size = x.size()[0]\n",
    "            self.train_x = x\n",
    "            self.train_y = y\n",
    "            self.Knn = self.kernel(x,x,self.dim_x)\n",
    "            self.In = torch.eye(self.train_size)\n",
    "            self.G = (self.Knn + self.var*self.In).inverse()\n",
    "            \n",
    "            \n",
    "    def predict_(self,pseudo_x,pred_x):\n",
    "            Kmm = self.kernel(pseudo_x,pseudo_x,self.dim_x)\n",
    "            Kmm_inv = Kmm.inverse()\n",
    "            Knm = self.kernel(self.train_x,pseudo_x,self.dim_x)\n",
    "            Kmn = Knm.t()\n",
    "            Kpm = self.kernel(pred_x,pseudo_x,self.dim_x)\n",
    "            Kmp = Kpm.t()\n",
    "            \n",
    "            _L =  self.Knn - Knm@Kmm_inv@Kmn\n",
    "            L = _L*I\n",
    "            print(L)\n",
    "            \n",
    "            F = (L + self.var*self.In).inverse()\n",
    "            \n",
    "            Q = Kmm + Kmn@F@Knm\n",
    "            Q_inv = Q.inverse()\n",
    "            \n",
    "            #print(Kpm.size(),Q_inv.size(),Kmn.size(),F.size(),self.train_y.size())\n",
    "            \n",
    "            myu_p = Kpm@Q_inv@Kmn@F@self.train_y\n",
    "                    \n",
    "            return myu_p\n",
    "\n",
    "    \n",
    "    def predict(self,pred_x):\n",
    "        Kpp = self.kernel(pred_x,pred_x,self.dim_x)\n",
    "        Kpn = self.kernel(pred_x,self.train_x,self.dim_x)\n",
    "        Knp = Kpn.t()\n",
    "        \n",
    "        myu_p = Kpn@self.G@self.train_y\n",
    "        sigma_p = Kpp - Kpn@self.G@Knp\n",
    "        \n",
    "        return myu_p,sigma_p\n",
    "        \n",
    "    \n",
    "    def distance_measure(self):\n",
    "        Kxn = self.kernel(x,self.train_x)\n",
    "        Knx = Kxn.t()\n",
    "        Kxx = self.kernel(x,self.train_x)\n",
    "        return Kxx - Kxn@self.Knn_inv@Knx\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'a' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-16-1ff5af9afa22>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mA_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mA_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-16-1ff5af9afa22>\u001b[0m in \u001b[0;36mfunc\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m      4\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mA_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'a' is not defined"
     ]
    }
   ],
   "source": [
    "class A:\n",
    "    def __init__(self,b):\n",
    "        a=3\n",
    "        self.b=b\n",
    "    def func(self):\n",
    "        return a,b\n",
    "    \n",
    "A_ = A(2)\n",
    "A_.func()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training data is 11 points in [0,1] inclusive regularly spaced\n",
    "train_x = torch.linspace(0, 1, 10)\n",
    "# True function is sin(2*pi*x) with Gaussian noise\n",
    "train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_x = torch.cat([torch.Tensor([a,b]).view(1,2) for a in train_x for b in train_x] ,dim=0  )\n",
    "grid_y = torch.cat([torch.cat( [(a+b).view(1,1),(a*b).view(1,1)],dim=1) for a in train_x for b in train_x] ,dim=0  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "pseudo_x = torch.cat([torch.Tensor([a,b]).view(1,2) for a,b in zip(train_x,train_x)] ,dim=0  )\n",
    "test_x = pseudo_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.0000],\n",
       "        [0.2222, 0.0123],\n",
       "        [0.4444, 0.0494],\n",
       "        [0.6667, 0.1111],\n",
       "        [0.8889, 0.1975],\n",
       "        [1.1111, 0.3086],\n",
       "        [1.3333, 0.4444],\n",
       "        [1.5556, 0.6049],\n",
       "        [1.7778, 0.7901],\n",
       "        [2.0000, 1.0000]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pseudo_y  = torch.cat([torch.cat( [(a+b).view(1,1),(a*b).view(1,1)],dim=1) for a,b in zip(train_x,train_x)] ,dim=0  )\n",
    "pseudo_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 1.6677e-02, -9.0348e-04],\n",
       "         [ 2.1537e-01,  1.1646e-02],\n",
       "         [ 4.4378e-01,  4.9612e-02],\n",
       "         [ 6.5938e-01,  1.1058e-01],\n",
       "         [ 8.8602e-01,  1.9647e-01],\n",
       "         [ 1.1137e+00,  3.1031e-01],\n",
       "         [ 1.3126e+00,  4.3717e-01],\n",
       "         [ 1.5338e+00,  5.9462e-01],\n",
       "         [ 1.8455e+00,  8.2673e-01],\n",
       "         [ 1.8585e+00,  9.1999e-01]]),\n",
       " tensor([[ 6.0444e-02,  2.5160e-03, -2.3574e-03,  1.3102e-03, -5.7346e-04,\n",
       "           1.9375e-04, -3.8712e-05, -3.6486e-06,  1.7500e-05, -3.2173e-06],\n",
       "         [ 2.5160e-03,  3.0043e-02,  9.1201e-03, -3.0400e-03,  8.6908e-04,\n",
       "          -1.4166e-04, -4.1396e-05,  5.2130e-05, -4.1865e-05,  1.7500e-05],\n",
       "         [-2.3578e-03,  9.1194e-03,  2.6938e-02,  8.6989e-03, -2.5032e-03,\n",
       "           7.1120e-04, -1.4443e-04, -7.0003e-06,  5.2124e-05, -3.6450e-06],\n",
       "         [ 1.3104e-03, -3.0401e-03,  8.6988e-03,  2.6539e-02,  8.8351e-03,\n",
       "          -2.5496e-03,  7.4777e-04, -1.4438e-04, -4.1408e-05, -3.8700e-05],\n",
       "         [-5.7343e-04,  8.6908e-04, -2.5027e-03,  8.8365e-03,  2.6249e-02,\n",
       "           8.8735e-03, -2.5487e-03,  7.1149e-04, -1.4167e-04,  1.9380e-04],\n",
       "         [ 1.9370e-04, -1.4181e-04,  7.1128e-04, -2.5488e-03,  8.8736e-03,\n",
       "           2.6250e-02,  8.8377e-03, -2.5021e-03,  8.6943e-04, -5.7323e-04],\n",
       "         [-3.8710e-05, -4.1411e-05, -1.4435e-04,  7.4786e-04, -2.5492e-03,\n",
       "           8.8362e-03,  2.6540e-02,  8.6993e-03, -3.0401e-03,  1.3103e-03],\n",
       "         [-3.6503e-06,  5.2113e-05, -7.0159e-06, -1.4455e-04,  7.1095e-04,\n",
       "          -2.5031e-03,  8.6996e-03,  2.6938e-02,  9.1198e-03, -2.3575e-03],\n",
       "         [ 1.7501e-05, -4.1863e-05,  5.2140e-05, -4.1414e-05, -1.4176e-04,\n",
       "           8.6894e-04, -3.0399e-03,  9.1205e-03,  3.0044e-02,  2.5166e-03],\n",
       "         [-3.2178e-06,  1.7500e-05, -3.6463e-06, -3.8713e-05,  1.9374e-04,\n",
       "          -5.7349e-04,  1.3102e-03, -2.3577e-03,  2.5159e-03,  6.0444e-02]]))"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dim_x = 2\n",
    "dim_y = 2\n",
    "var = 0.1\n",
    "sgp = SparseGaussianProcess(dim_x,dim_y,var)\n",
    "sgp.train(grid_x,grid_y)\n",
    "sgp.predict_(test_x)\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
